{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jsonlines\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from utils import read_data, extract_data, partition\n",
    "from evaluation import calculate_score, process_CTC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tasks to score; the results table uses this same order for its columns.\n",
    "task_list = (\n",
    "    'CMeIE',\n",
    "    'CHIP-CDN',\n",
    "    'CHIP-CDEE',\n",
    "    'CHIP-MDCFNPC',\n",
    "    'CHIP-CTC',\n",
    "    'KUAKE-QIC',\n",
    "    'IMCS-V2-MRG',\n",
    "    'MedDG',\n",
    ")\n",
    "\n",
    "# Root folders: predictions (pred_path) and the files they are scored against (true_path).\n",
    "pred_path = \"pred\"\n",
    "true_path = \"true\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Entries under pred_path to evaluate; each one becomes a row of the score table.\n",
    "# NOTE(review): spelled 'predications' here but 'test_predictions.json' in the scoring cell -- confirm the name on disk.\n",
    "target_list = [\"test_predications.json\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One list of per-task scores per target, filled in task_list order.\n",
    "score_dict = {name: [] for name in target_list}\n",
    "label_dict = {}\n",
    "\n",
    "for target in target_list:\n",
    "    target_path = os.path.join(pred_path, target)\n",
    "\n",
    "    # The first task's folder is missing, so the combined predictions still need partitioning.\n",
    "    first_task_dir = os.path.join(target_path, task_list[0])\n",
    "    if not os.path.exists(first_task_dir):\n",
    "        all_data = read_data(target_path)\n",
    "        partition(extract_data(all_data), task_list, target_path)\n",
    "\n",
    "    for task in task_list:\n",
    "        pred_file = os.path.join(target_path, task, \"test_predictions.json\")\n",
    "        true_file = os.path.join(true_path, task, \"test.json\")\n",
    "\n",
    "        # Only CHIP-CTC needs a post-processing step before scoring.\n",
    "        post_process_function = process_CTC if task == \"CHIP-CTC\" else None\n",
    "\n",
    "        score, labels, _ = calculate_score(task, pred_file, true_file, post_process_function)\n",
    "        score_dict[target].append(score)\n",
    "        label_dict[task] = labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The code above only splits test_prediction into the per-task folders;\n",
    "# it does not put the prediction in, so pred/test.json is partitioned here.\n",
    "target_path = os.path.join(pred_path, \"test.json\")  # pred/test.json\n",
    "all_data = read_data(target_path)\n",
    "\n",
    "extracted_data = extract_data(all_data)\n",
    "partition(extracted_data, task_list, pred_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows are the evaluated targets, columns are the tasks in task_list order.\n",
    "res_key = list(score_dict.keys())\n",
    "res_data = list(score_dict.values())\n",
    "\n",
    "res_df = pd.DataFrame(data=res_data, index=res_key, columns=task_list)\n",
    "\n",
    "# Row-wise mean over all task scores.\n",
    "res_df[\"average\"] = res_df.mean(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the score table (first 20 rows at most).\n",
    "res_df.head(20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the tasks excluded from this average, plus any existing \"average\" column,\n",
    "# then recompute the row-wise mean over what remains.\n",
    "# errors=\"ignore\" skips labels that are not present, so no blanket try/except is\n",
    "# needed; drop() returns a new frame, so res_df itself is left unchanged.\n",
    "new_res_df = res_df.drop(columns=[\"CHIP-STS\", \"KUAKE-IR\", \"average\"], errors=\"ignore\")\n",
    "new_res_df[\"average\"] = new_res_df.mean(axis=1)\n",
    "new_res_df.head(50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the moelora predictions of selected tasks against each task's dev.json.\n",
    "for st in [\"CHIP-CDN\", \"CHIP-MDCFNPC\", \"IMCS-V2-MRG\", \"KUAKE-QIC\"]:\n",
    "    # Pick the post-processing per task (same rule as the main scoring loop) instead of\n",
    "    # reusing whatever value an earlier cell left in `post_process_function`.\n",
    "    st_post_process = process_CTC if st == \"CHIP-CTC\" else None\n",
    "    score, _, _ = calculate_score(st,\n",
    "                                  \"pred/moelora/%s/test_predictions.json\" % st,\n",
    "                                  \"true/%s/dev.json\" % st,\n",
    "                                  st_post_process)\n",
    "    print(\"The score for task %s is: %.5f\" % (st, score))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
